import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Public entry point that forwards every argument to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor`` object, after it has been filled.
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled
    
    
   
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Fill tensor with uniform values from [l, u]
        tensor.uniform_(l, u)

        # Use inverse cdf transform from normal distribution
        tensor.mul_(2)
        tensor.sub_(1)

        # Ensure that the values are strictly between -1 and 1 for erfinv
        eps = torch.finfo(tensor.dtype).eps
        tensor.clamp_(min=-(1. - eps), max=(1. - eps))
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp one last time to ensure it's still in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

class Gaussian(nn.Module):
    """Element-wise Gaussian distribution parameterised by ``mu`` and ``rho``.

    The standard deviation is not stored directly. It is derived as
    ``sigma = softplus(rho) = log(1 + exp(rho))`` so that it stays positive
    whatever value ``rho`` takes during optimisation.

    Args:
        mu: tensor of means; wrapped in an ``nn.Parameter``.
        rho: tensor of pre-softplus scale values, same shape as ``mu``;
            wrapped in an ``nn.Parameter``.
        device: stored as ``self.device`` for backward compatibility.
            Sampling follows the device the parameters actually live on.
        fixed: when True both parameters are created with
            ``requires_grad=False``.
    """

    def __init__(self, mu, rho, device='cuda', fixed=False):
        super().__init__()
        self.mu = nn.Parameter(mu, requires_grad=not fixed)
        self.rho = nn.Parameter(rho, requires_grad=not fixed)
        self.device = device

    @property
    def sigma(self):
        """Standard deviation, ``log(1 + exp(rho))``, computed stably."""
        # F.softplus evaluates the same function as log(1 + exp(rho)) but does
        # not overflow to inf for large rho, and does not round to exactly 0
        # for very negative rho (which would produce inf/nan in compute_kl).
        return F.softplus(self.rho)

    def sample(self):
        """Return one reparameterised sample ``mu + sigma * epsilon``."""
        sigma = self.sigma
        # Noise is still drawn from the default (CPU) generator, so seeded
        # draws are unchanged. It is then moved to wherever the parameters
        # live rather than to the stored `self.device` string, which can
        # disagree with the parameters (e.g. the 'cuda' default on a CPU-only
        # machine).
        epsilon = torch.randn(sigma.size()).to(sigma.device)
        return self.mu + sigma * epsilon

    def compute_kl(self, other):
        """Return KL(self || other) summed over all elements.

        Args:
            other: another ``Gaussian`` whose ``mu`` and ``sigma`` broadcast
                against this one's.

        Returns:
            A scalar tensor:
            ``0.5 * sum(log(b0 / b1) + (mu1 - mu0)^2 / b0 + b1 / b0 - 1)``,
            where ``b1`` is this distribution's variance and ``b0`` is the
            variance of ``other``.
        """
        var_self = torch.pow(self.sigma, 2)
        var_other = torch.pow(other.sigma, 2)

        log_ratio = torch.log(torch.div(var_other, var_self))
        mean_term = torch.div(torch.pow(self.mu - other.mu, 2), var_other)
        var_ratio = torch.div(var_self, var_other)
        kl_div = (torch.mul(log_ratio + mean_term + var_ratio - 1, 0.5)).sum()
        return kl_div

class ProbLinear(nn.Module):
    """Linear layer whose weight and bias are ``Gaussian`` distributions.

    The posterior (``self.weight``, ``self.bias``) is trainable. The prior
    (``self.weight_prior``, ``self.bias_prior``) is created with
    ``fixed=True`` from clones of the same initial means and rho values.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        rho_prior: scalar used to fill the initial rho tensors of both the
            posterior and the prior.
        prior_dist: accepted for interface compatibility; not used here
            (the distribution is always ``Gaussian``).
        device: forwarded to the ``Gaussian`` instances.
        init_prior: accepted for interface compatibility; not used here.
        init_layer: tensor of shape ``(out_features, in_features)`` giving
            the initial weight means. Required despite the ``None`` default.
        init_layer_prior: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``init_layer`` is None or has the wrong shape.
    """

    def __init__(self, in_features, out_features, rho_prior, prior_dist='gaussian', device='cuda', init_prior='weights', init_layer=None, init_layer_prior=None):
        super().__init__()

        # Fail early with a clear message instead of an AttributeError from
        # calling `.clone()` on None further down.
        if init_layer is None:
            raise ValueError(
                "ProbLinear requires `init_layer`: a tensor of shape "
                "(out_features, in_features) with the initial weight means.")
        expected_shape = (out_features, in_features)
        if tuple(init_layer.shape) != expected_shape:
            # A mismatched mean tensor would otherwise either broadcast
            # silently against rho or fail later inside sample()/F.linear.
            raise ValueError(
                f"init_layer has shape {tuple(init_layer.shape)}, "
                f"expected {expected_shape}.")

        self.in_features = in_features
        self.out_features = out_features

        # Initial posterior means.
        weights_mu_init = init_layer
        bias_mu_init = torch.zeros(out_features) + 0.001

        # Initial rho values, shared by posterior and prior.
        weights_rho_init = torch.ones(out_features, in_features) * rho_prior
        bias_rho_init = torch.ones(out_features) * rho_prior

        # The prior is centred on the initial posterior means.
        weights_mu_prior = weights_mu_init
        bias_mu_prior = bias_mu_init

        dist = Gaussian

        # Each distribution gets its own clone so no storage is shared
        # between the posterior and the prior.
        self.bias = dist(bias_mu_init.clone(),
                         bias_rho_init.clone(), device=device, fixed=False)
        self.weight = dist(weights_mu_init.clone(),
                           weights_rho_init.clone(), device=device, fixed=False)
        self.weight_prior = dist(
            weights_mu_prior.clone(), weights_rho_init.clone(), device=device, fixed=True)
        self.bias_prior = dist(
            bias_mu_prior.clone(), bias_rho_init.clone(), device=device, fixed=True)

        # Refreshed on every training-mode forward pass.
        self.kl_div = 0

    def forward(self, input, sample=False):
        """Apply the layer to ``input``.

        Args:
            input: tensor accepted by ``F.linear`` for this weight shape.
            sample: when True, sample weight and bias even outside training
                mode (stochastic / ensemble prediction).

        Returns:
            ``F.linear(input, weight, bias)`` using sampled parameters in
            training mode or when ``sample`` is True, and the posterior means
            otherwise.
        """
        if self.training or sample:
            # during training we sample from the model distribution
            # sample = True can also be set during testing if we
            # want to use the stochastic/ensemble predictors
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            # otherwise we use the posterior mean
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training:
            # sum of the KL computed for weights and biases
            self.kl_div = self.weight.compute_kl(self.weight_prior) + \
                self.bias.compute_kl(self.bias_prior)

        return F.linear(input, weight, bias)



class ProbNNet4l(nn.Module):
    """Four-layer fully connected probabilistic network for 28x28 inputs.

    Layer sizes are ``784 -> Net_Width -> Net_Width -> Net_Width -> 10``, each
    layer being a ``ProbLinear``.

    Args:
        rho_prior: forwarded to every ``ProbLinear``.
        prior_dist: forwarded to every ``ProbLinear``.
        device: forwarded to every ``ProbLinear``.
        init_net: indexable collection whose items 0..3 are passed as
            ``init_layer`` to layers 1..4; when falsy, ``None`` is passed.
        Net_Width: width of the three hidden layers.
    """

    def __init__(self, rho_prior, prior_dist='gaussian', device='cuda', init_net=None, Net_Width=600):
        super().__init__()

        def initial_weights(index):
            # Pick the per-layer initial weights, or None when none were given.
            return init_net[index] if init_net else None

        shared = dict(prior_dist=prior_dist, device=device)
        self.l1 = ProbLinear(28 * 28, Net_Width, rho_prior,
                             init_layer=initial_weights(0), **shared)
        self.l2 = ProbLinear(Net_Width, Net_Width, rho_prior,
                             init_layer=initial_weights(1), **shared)
        self.l3 = ProbLinear(Net_Width, Net_Width, rho_prior,
                             init_layer=initial_weights(2), **shared)
        self.l4 = ProbLinear(Net_Width, 10, rho_prior,
                             init_layer=initial_weights(3), **shared)
        self.Net_Width = Net_Width

    def forward(self, x, sample=False, clamping=True, pmin=1e-4):
        """Return the network output for ``x``.

        ``x`` is reshaped to ``(-1, 784)``. Each layer's output is scaled by
        ``1 / sqrt(fan_in)`` of that layer; the first three are passed
        through ReLU before scaling. ``clamping`` and ``pmin`` are accepted
        but not used by this method.
        """
        input_scale = 1 / np.sqrt(28 * 28)
        hidden_scale = 1 / np.sqrt(self.Net_Width)

        flat = x.view(-1, 28 * 28)
        hidden1 = input_scale * F.relu(self.l1(flat, sample))
        hidden2 = hidden_scale * F.relu(self.l2(hidden1, sample))
        hidden3 = hidden_scale * F.relu(self.l3(hidden2, sample))
        return hidden_scale * self.l4(hidden3, sample)

    def compute_kl(self):
        """Return the sum of the four layers' stored ``kl_div`` values."""
        kl_first, kl_second = self.l1.kl_div, self.l2.kl_div
        kl_third, kl_fourth = self.l3.kl_div, self.l4.kl_div
        return kl_first + kl_second + kl_third + kl_fourth


def output_transform(x, clamping=True, pmin=1e-4):
    """Convert raw scores to log-probabilities, optionally bounded below.

    Args:
        x: tensor of scores; log-softmax is taken along dimension 1.
        clamping: when True, clamp the log-probabilities from below.
        pmin: probability whose logarithm is the lower clamp bound.

    Returns:
        Log-softmax of ``x`` along dim 1, with every entry at least
        ``log(pmin)`` when ``clamping`` is True.
    """
    log_probs = F.log_softmax(x, dim=1)
    if not clamping:
        return log_probs
    return torch.clamp(log_probs, min=np.log(pmin))

def one_hot(x, class_count):
    """Return one-hot rows for the class indices in ``x``.

    Args:
        x: index (or tensor of indices) used to select rows.
        class_count: number of classes, i.e. the length of each one-hot row.

    Returns:
        The rows of a ``class_count x class_count`` identity matrix selected
        by ``x``.
    """
    identity = torch.eye(class_count)
    return identity[x, :]

def trainPNNet(net, optimizer, pbobj, epoch, train_loader, lambda_var=None, optimizer_lambda=None, verbose=False):
    """Run one pass over ``train_loader``, stepping ``optimizer`` on each batch.

    For every batch the targets are one-hot encoded (10 classes, hard-coded),
    moved to ``pbobj.device`` together with the data, and the first value
    returned by ``pbobj.train_obj`` is back-propagated.

    Args:
        net: model to train; put into training mode via ``net.train()``.
        optimizer: optimizer whose ``step()`` is called once per batch.
        pbobj: objective object providing ``device`` and ``train_obj``.
        epoch: accepted for interface compatibility; currently unused.
        train_loader: iterable yielding ``(data, target)`` batches.
        lambda_var: forwarded to ``pbobj.train_obj``.
        optimizer_lambda: accepted for interface compatibility; currently
            unused.
        verbose: accepted for interface compatibility; currently unused
            (per-epoch metric reporting is disabled).

    Returns:
        None.
    """
    net.train()

    # Clamping is always enabled here. Earlier, now-removed code turned it off
    # when `pbobj.objective == 'bbb'`.
    clamping = True

    for data, target in tqdm(train_loader):
        target = one_hot(target, 10)
        data, target = data.to(pbobj.device), target.to(pbobj.device)
        net.zero_grad()
        # train_obj returns five values; only the first (the training
        # objective) is needed for the update.
        bound, _, _, _, _ = pbobj.train_obj(
            net, data, target, lambda_var=lambda_var, clamping=clamping)
        bound.backward()
        optimizer.step()




